/*
 * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <cuopt/linear_programming/mip/solver_settings.hpp>
#include <cuopt/linear_programming/solve.hpp>
#include <mip/problem/problem.cuh>
#include <mps_parser/parser.hpp>
#include <utilities/copy_helpers.hpp>

namespace cuopt::linear_programming::test {

static void test_variable_bounds(
  const cuopt::mps_parser::mps_data_model_t<int, double>& problem,
  const rmm::device_uvector<double>& solution,
  const cuopt::linear_programming::mip_solver_settings_t<int, double> settings)
{
  const double* lower_bound_ptr = problem.get_variable_lower_bounds().data();
  const double* upper_bound_ptr = problem.get_variable_upper_bounds().data();
  auto host_assignment          = cuopt::host_copy(solution);
  double* assignment_ptr        = host_assignment.data();
  cuopt_assert(host_assignment.size() == problem.get_variable_lower_bounds().size(), "");
  cuopt_assert(host_assignment.size() == problem.get_variable_upper_bounds().size(), "");
  std::vector<int> indices(host_assignment.size());
  std::iota(indices.begin(), indices.end(), 0);
  bool result = std::all_of(indices.begin(), indices.end(), [=](int idx) {
    bool res = true;
    if (lower_bound_ptr != nullptr) {
      res = res && (assignment_ptr[idx] >=
                    lower_bound_ptr[idx] - settings.tolerances.integrality_tolerance);
    }
    if (upper_bound_ptr != nullptr) {
      res = res && (assignment_ptr[idx] <=
                    upper_bound_ptr[idx] + settings.tolerances.integrality_tolerance);
    }
    return res;
  });
  EXPECT_TRUE(result);
}

template <typename f_t>
static double combine_finite_abs_bounds(f_t lower, f_t upper)
{
  f_t val = f_t(0);
  if (isfinite(upper)) { val = raft::max<f_t>(val, raft::abs(upper)); }
  if (isfinite(lower)) { val = raft::max<f_t>(val, raft::abs(lower)); }
  return val;
}

template <typename f_t>
struct violation {
  violation() {}
  violation(f_t* _scalar) {}
  __device__ __host__ f_t operator()(f_t value, f_t lower, f_t upper)
  {
    if (value < lower) {
      return lower - value;
    } else if (value > upper) {
      return value - upper;
    }
    return f_t(0);
  }
};

static void test_constraint_sanity_per_row(
  const cuopt::mps_parser::mps_data_model_t<int, double>& op_problem,
  const rmm::device_uvector<double>& solution,
  double abs_tolerance,
  double rel_tolerance)
{
  const std::vector<double>& values                  = op_problem.get_constraint_matrix_values();
  const std::vector<int>& indices                    = op_problem.get_constraint_matrix_indices();
  const std::vector<int>& offsets                    = op_problem.get_constraint_matrix_offsets();
  const std::vector<double>& constraint_lower_bounds = op_problem.get_constraint_lower_bounds();
  const std::vector<double>& constraint_upper_bounds = op_problem.get_constraint_upper_bounds();
  const std::vector<double>& variable_lower_bounds   = op_problem.get_variable_lower_bounds();
  const std::vector<double>& variable_upper_bounds   = op_problem.get_variable_upper_bounds();
  std::vector<double> residual(constraint_lower_bounds.size(), 0.0);
  std::vector<double> viol(constraint_lower_bounds.size(), 0.0);
  auto h_solution = cuopt::host_copy(solution);
  // CSR SpMV
  for (size_t i = 0; i < offsets.size() - 1; ++i) {
    for (int j = offsets[i]; j < offsets[i + 1]; ++j) {
      residual[i] += values[j] * h_solution[indices[j]];
    }
  }

  auto functor = violation<double>{};

  // Compute violation to lower/upper bound
  for (size_t i = 0; i < residual.size(); ++i) {
    double tolerance = abs_tolerance + combine_finite_abs_bounds<double>(
                                         constraint_lower_bounds[i], constraint_upper_bounds[i]) *
                                         rel_tolerance;
    double viol = functor(residual[i], constraint_lower_bounds[i], constraint_upper_bounds[i]);
    EXPECT_LE(viol, tolerance);
  }
}

static std::tuple<mip_termination_status_t, double, double> test_mps_file(
  std::string test_instance,
  double time_limit    = 1,
  bool heuristics_only = true,
  bool presolve        = true)
{
  const raft::handle_t handle_{};

  auto path = make_path_absolute(test_instance);
  cuopt::mps_parser::mps_data_model_t<int, double> problem =
    cuopt::mps_parser::parse_mps<int, double>(path, false);
  handle_.sync_stream();
  mip_solver_settings_t<int, double> settings;
  settings.time_limit                  = time_limit;
  settings.heuristics_only             = heuristics_only;
  settings.presolve                    = presolve;
  mip_solution_t<int, double> solution = solve_mip(&handle_, problem, settings);
  return std::make_tuple(solution.get_termination_status(),
                         solution.get_objective_value(),
                         solution.get_solution_bound());
}

}  // namespace cuopt::linear_programming::test
